import os
import shutil
import random
from pathlib import Path

# Seed the module-level RNG at import time so that random.shuffle() in
# split_dataset yields the same permutation for the same input list on every run.
random.seed(1)


def split_dataset(source_dir, target_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """
    Split the class subdirectories of a source directory into train/val/test sets.

    The split is done at class granularity: every immediate subdirectory of
    ``source_dir`` is treated as one class and is moved, as a whole, into
    ``<target_dir>/train``, ``<target_dir>/val`` or ``<target_dir>/test``.
    The class list is sorted and then shuffled with the module-level ``random``
    generator, so a fixed seed gives the same split regardless of the order in
    which the filesystem lists the directories.

    If a destination class folder already exists it is deleted and replaced.

    Args:
        source_dir (str): Path to the source directory (containing class subdirectories)
        target_dir (str): Path to the target main directory (train/val/test subdirectories
            are created inside it). It may be the same directory as ``source_dir``;
            the split directories themselves are never treated as classes.
        train_ratio (float): Fraction of classes for the training set (rounded down)
        val_ratio (float): Fraction of classes for the validation set (rounded down)
        test_ratio (float): Accepted for interface compatibility only. It is not used
            in the computation: the test set always receives every class that is
            left after the train and val sets have been filled.

    Raises:
        ValueError: If ``train_ratio`` or ``val_ratio`` is negative, or if
            ``train_ratio + val_ratio`` exceeds 1.
    """
    # Tolerance for binary floating-point error, e.g. 90 * 0.7 evaluates to
    # 62.99999999999999 and would otherwise be truncated to 62 instead of 63.
    eps = 1e-9

    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + eps:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1 "
            f"(got train_ratio={train_ratio}, val_ratio={val_ratio})"
        )

    split_dirs = {name: os.path.join(target_dir, name) for name in ('train', 'val', 'test')}
    reserved = {os.path.realpath(path) for path in split_dirs.values()}

    # Get all class folders. The split directories are excluded so that
    # target_dir == source_dir (or a re-run) never tries to move 'train' into itself.
    # Sorting first makes the seeded shuffle independent of os.listdir() ordering.
    class_folders = sorted(
        entry for entry in os.listdir(source_dir)
        if os.path.isdir(os.path.join(source_dir, entry))
        and os.path.realpath(os.path.join(source_dir, entry)) not in reserved
    )
    random.shuffle(class_folders)  # Shuffle class order randomly

    # Create target directories only after the source has been listed successfully
    for path in split_dirs.values():
        os.makedirs(path, exist_ok=True)

    total_classes = len(class_folders)

    # Calculate number of classes for each split (test gets the remainder)
    num_train = int(total_classes * train_ratio + eps)
    num_val = min(int(total_classes * val_ratio + eps), total_classes - num_train)
    num_test = total_classes - num_train - num_val

    print(f"Total classes: {total_classes}")
    print(f"Train classes: {num_train}")
    print(f"Val classes: {num_val}")
    print(f"Test classes: {num_test}")

    # Split dataset
    assignments = (
        ('train', class_folders[:num_train]),
        ('val', class_folders[num_train:num_train + num_val]),
        ('test', class_folders[num_train + num_val:]),
    )

    # Move folders to corresponding directories
    for split_name, classes in assignments:
        dest_dir = split_dirs[split_name]
        for class_name in classes:
            src = os.path.join(source_dir, class_name)
            dst = os.path.join(dest_dir, class_name)
            # An existing destination is replaced rather than merged into
            if os.path.exists(dst):
                shutil.rmtree(dst)
            shutil.move(src, dst)
            print(f"Moved {class_name} to {dest_dir}")

    print("Dataset splitting completed!")


if __name__ == "__main__":
    # Base directory: fill in your own path. The class folders are expected
    # under "<main_dir>/derm"; the train/val/test folders are created directly
    # inside main_dir.
    main_dir = ""  # your dir
    target_dir = main_dir
    source_dir = os.path.join(main_dir, "derm")

    # Run the split only when the source directory is present; otherwise abort
    if os.path.exists(source_dir):
        split_dataset(source_dir, target_dir)
    else:
        raise ValueError(f"Source directory {source_dir} does not exist!")